#!/usr/bin/env python
# -*- coding:utf-8 -*-
from abc import ABCMeta

import six

from util import load_tokenizer, load_model, tokenize
from util.jb import fc

__all__ = ['ModelFactory', 'Model']


class ModelFactory(object):
    """Model factory: hands out instances of Model subclasses."""

    @staticmethod
    def get_model(sub_cls, update=False):
        """Return ``sub_cls()``, optionally reloading an existing instance first.

        ``Model.__new__`` caches instances, so repeated calls hand back the
        same object for a given class.

        :param sub_cls: ``Model`` or a subclass of it.
        :param update: when True and an instance of ``sub_cls`` already exists,
            call its ``_update()`` (reload model and tokenizer) before returning.
            When no instance exists yet there is nothing to reload: constructing
            it below loads everything anyway.
        :return: the object produced by ``sub_cls()``.
        :raises SyntaxError: if ``sub_cls`` is not a subclass of ``Model``
            (kept for backward compatibility with existing callers).
        :raises TypeError: propagated from ``issubclass`` if ``sub_cls`` is
            not a class.
        """
        if not issubclass(sub_cls, Model):
            raise SyntaxError('%s not subclass for %s' % (sub_cls.__name__, Model.__name__))
        # Look only in sub_cls's own namespace: ``sub_cls._instance`` would fall
        # back through the MRO to a parent's cached instance or to the ``None``
        # default, and calling ``_update()`` on ``None`` raised AttributeError.
        existing = vars(sub_cls).get('_instance')
        if update and existing is not None:
            existing._update()
        return sub_cls()


@six.add_metaclass(ABCMeta)
class Model(object):
    """Model: holds the word-vector tokenizer ("word2var", presumably word2vec)
    and the LSTM model.

    Each concrete class is a singleton: constructing it again returns the same
    instance, and the tokenizer/model are loaded only on first construction.
    Call ``_update()`` to reload them from the configured paths.
    """
    tokenize_path = ''   # passed to load_tokenizer()
    model_path = ''      # first argument to load_model()
    weights_path = ''    # second argument to load_model()
    max_len = None       # passed to tokenize(); presumably the padded sequence length — confirm
    _instance = None     # cached singleton, stored separately on each class by __new__

    def __new__(cls, *args, **kwargs):
        """Singleton: return the cached instance of exactly ``cls``, creating it on first use."""
        # Read cls's own __dict__ rather than ``cls._instance``: plain attribute
        # lookup falls through the MRO, so a subclass constructed after its
        # parent would otherwise get the parent's instance back.
        instance = cls.__dict__.get('_instance')
        if instance is None:
            # object.__new__ rejects extra arguments when __new__ is overridden,
            # so constructor arguments are left to __init__ only.
            instance = super(Model, cls).__new__(cls)
            cls._instance = instance
        return instance

    def __init__(self):
        super(Model, self).__init__()
        # Python calls __init__ on every ``cls()`` even when __new__ returned the
        # cached instance; load the model and tokenizer only the first time.
        # The flag is set after a successful load so a failed load is retried.
        if getattr(self, '_loaded', False):
            return
        self._update()
        self._loaded = True

    def predict(self, texts):
        """Classify ``texts`` and return the predicted labels.

        Each text is first mapped through ``fc`` (from util.jb; presumably
        jieba-based preprocessing — confirm), then converted to model input
        with ``tokenize(tokenizer, texts, self.max_len)``.

        :param texts: iterable of raw input texts.
        :return: the result of the loaded model's ``predict_classes``.
        """
        texts = list(map(fc, texts))
        combined = tokenize(self.__t, texts, self.max_len)
        # NOTE(review): if load_model returns a tf.keras model, predict_classes
        # no longer exists in TF >= 2.6; confirm the pinned Keras version.
        label = self.__m.predict_classes(combined)
        return label

    def _update(self):
        """(Re)load the model and the tokenizer from the configured paths."""
        self.__m = load_model(self.model_path, self.weights_path)
        self.__t = load_tokenizer(self.tokenize_path)


if __name__ == '__main__':
    # Script entry: build (or fetch the cached) base Model through the factory,
    # without forcing a reload.
    model = ModelFactory.get_model(sub_cls=Model, update=False)
